package org.apache.seatunnel.connectors.seatunnel.gaussdbmongodb.serde;

import org.apache.seatunnel.api.table.type.*;
import org.apache.seatunnel.common.exception.CommonErrorCode;
import org.apache.seatunnel.connectors.seatunnel.gaussdbmongodb.exception.MongodbConnectorException;
import org.bson.*;
import org.bson.json.JsonParseException;
import org.bson.types.Decimal128;

import java.io.Serializable;
import java.math.BigDecimal;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.*;
import java.util.stream.Collectors;

import static org.apache.seatunnel.api.table.type.SqlType.NULL;
import static org.apache.seatunnel.common.exception.CommonErrorCode.UNSUPPORTED_DATA_TYPE;
import static org.apache.seatunnel.connectors.seatunnel.gaussdbmongodb.config.MongodbConfig.ENCODE_VALUE_FIELD;
import static org.apache.seatunnel.connectors.seatunnel.gaussdbmongodb.serde.BsonToRowDataConverters.fromBigDecimal;
/**
 * Builds converters that turn {@link SeaTunnelRow}s, and the values nested inside them, into BSON
 * values.
 *
 * <p>Type dispatch happens once, when a converter is created for a {@link SeaTunnelDataType}. The
 * returned per-value functions only cast and wrap. Every converter is {@link Serializable}.
 */
public class RowDataToBsonConverters implements Serializable {

    private static final long serialVersionUID = 1L;

    /** Converts a whole {@link SeaTunnelRow} into a {@link BsonDocument}. */
    @FunctionalInterface
    public interface RowDataToBsonConverter extends Serializable {
        BsonDocument convert(SeaTunnelRow rowData);
    }

    /**
     * Creates a converter for rows of the given type.
     *
     * @param type the row type; a {@link SeaTunnelRowType} yields a {@link BsonDocument} per row
     * @return a converter returning the converted document, or {@code null} when the conversion
     *     result for {@code type} is not a {@link BsonDocument}
     * @throws MongodbConnectorException if {@code type}, or a type nested in it, is not supported
     */
    public static RowDataToBsonConverter createConverter(SeaTunnelDataType<?> type) {
        SerializableFunction<Object, BsonValue> internalRowConverter =
                createNullSafeInternalConverter(type);
        return new RowDataToBsonConverter() {
            private static final long serialVersionUID = 1L;

            @Override
            public BsonDocument convert(SeaTunnelRow rowData) {
                // Convert once and reuse the result for both the type check and the cast.
                BsonValue converted = internalRowConverter.apply(rowData);
                return converted instanceof BsonDocument ? (BsonDocument) converted : null;
            }
        };
    }

    /** Creates the converter for {@code type}, wrapped by the null-rejecting guard. */
    private static SerializableFunction<Object, BsonValue> createNullSafeInternalConverter(
            SeaTunnelDataType<?> type) {
        return wrapIntoNullSafeInternalConverter(createInternalConverter(type), type);
    }

    /**
     * Wraps {@code internalConverter} with a null guard.
     *
     * <p>Despite its name, this guard does not write {@code BsonNull}. A {@code null} value, or any
     * value of a column whose SQL type is {@code NULL}, is rejected with an exception.
     *
     * @throws MongodbConnectorException (from the returned function) on a null value or NULL type
     */
    private static SerializableFunction<Object, BsonValue> wrapIntoNullSafeInternalConverter(
            SerializableFunction<Object, BsonValue> internalConverter, SeaTunnelDataType<?> type) {
        return new SerializableFunction<Object, BsonValue>() {
            private static final long serialVersionUID = 1L;

            @Override
            public BsonValue apply(Object value) {
                if (value == null || NULL.equals(type.getSqlType())) {
                    throw new MongodbConnectorException(
                            UNSUPPORTED_DATA_TYPE,
                            "The column type is <"
                                    + type
                                    + ">, but a null value is being written into it");
                }
                return internalConverter.apply(value);
            }
        };
    }

    /**
     * Builds the error raised when a runtime value does not match its declared column type.
     *
     * <p>Used instead of returning a {@code null} {@link BsonValue}. A null result would only fail
     * later and far from the cause: {@code BsonDocument.append} rejects null values, and a null
     * inside a {@code BsonArray} fails only at encode time.
     */
    private static MongodbConnectorException unsupportedValue(
            SeaTunnelDataType<?> type, Object value) {
        String valueClass = value == null ? "null" : value.getClass().getName();
        return new MongodbConnectorException(
                UNSUPPORTED_DATA_TYPE,
                "The column type is <"
                        + type
                        + ">, but a value of class <"
                        + valueClass
                        + "> is being written into it");
    }

    /**
     * Creates the raw, un-guarded converter for a single SeaTunnel type.
     *
     * @throws MongodbConnectorException if the type is not supported
     */
    private static SerializableFunction<Object, BsonValue> createInternalConverter(
            SeaTunnelDataType<?> type) {
        switch (type.getSqlType()) {
            case NULL:
                // Only reachable if the null guard stops rejecting NULL-typed columns.
                return new SerializableFunction<Object, BsonValue>() {
                    private static final long serialVersionUID = 1L;

                    @Override
                    public BsonValue apply(Object value) {
                        return BsonNull.VALUE;
                    }
                };
            case BOOLEAN:
                return new SerializableFunction<Object, BsonValue>() {
                    private static final long serialVersionUID = 1L;

                    @Override
                    public BsonValue apply(Object value) {
                        return new BsonBoolean((boolean) value);
                    }
                };
            case TINYINT:
            case SMALLINT:
            case INT:
                return new SerializableFunction<Object, BsonValue>() {
                    private static final long serialVersionUID = 1L;

                    @Override
                    public BsonValue apply(Object value) {
                        // Widen with sign extension: a TINYINT of -1 must be stored as -1, not 255.
                        int intValue;
                        if (value instanceof Byte) {
                            intValue = ((Byte) value).intValue();
                        } else if (value instanceof Short) {
                            intValue = ((Short) value).intValue();
                        } else {
                            intValue = (int) value;
                        }
                        return new BsonInt32(intValue);
                    }
                };
            case BIGINT:
                return new SerializableFunction<Object, BsonValue>() {
                    private static final long serialVersionUID = 1L;

                    @Override
                    public BsonValue apply(Object value) {
                        return new BsonInt64((long) value);
                    }
                };
            case FLOAT:
            case DOUBLE:
                return new SerializableFunction<Object, BsonValue>() {
                    private static final long serialVersionUID = 1L;

                    @Override
                    public BsonValue apply(Object value) {
                        double v =
                                value instanceof Float
                                        ? ((Float) value).doubleValue()
                                        : (double) value;
                        return new BsonDouble(v);
                    }
                };
            case STRING:
                return new SerializableFunction<Object, BsonValue>() {
                    private static final long serialVersionUID = 1L;

                    @Override
                    public BsonValue apply(Object value) {
                        String val = value.toString();
                        // A JSON object holding ENCODE_VALUE_FIELD may carry a MongoDB-specific
                        // value encoded as extended JSON. If so, unwrap and store that value.
                        if (val.startsWith("{")
                                && val.endsWith("}")
                                && val.contains(ENCODE_VALUE_FIELD)) {
                            try {
                                BsonDocument doc = BsonDocument.parse(val);
                                if (doc.containsKey(ENCODE_VALUE_FIELD)) {
                                    return doc.get(ENCODE_VALUE_FIELD);
                                }
                            } catch (JsonParseException ignored) {
                                // Not valid JSON after all: fall back to a plain BSON string.
                            }
                        }
                        return new BsonString(val);
                    }
                };
            case BYTES:
                return new SerializableFunction<Object, BsonValue>() {
                    private static final long serialVersionUID = 1L;

                    @Override
                    public BsonValue apply(Object value) {
                        return new BsonBinary((byte[]) value);
                    }
                };
            case DATE:
                return new SerializableFunction<Object, BsonValue>() {
                    private static final long serialVersionUID = 1L;

                    @Override
                    public BsonValue apply(Object value) {
                        if (!(value instanceof LocalDate)) {
                            throw unsupportedValue(type, value);
                        }
                        // Midnight of the date in the JVM's default time zone, as epoch millis.
                        LocalDate localDate = (LocalDate) value;
                        return new BsonDateTime(
                                localDate
                                        .atStartOfDay(ZoneId.systemDefault())
                                        .toInstant()
                                        .toEpochMilli());
                    }
                };
            case TIMESTAMP:
                return new SerializableFunction<Object, BsonValue>() {
                    private static final long serialVersionUID = 1L;

                    @Override
                    public BsonValue apply(Object value) {
                        if (!(value instanceof LocalDateTime)) {
                            throw unsupportedValue(type, value);
                        }
                        // Interpreted in the JVM's default time zone, stored as epoch millis.
                        LocalDateTime localDateTime = (LocalDateTime) value;
                        return new BsonDateTime(
                                localDateTime
                                        .atZone(ZoneId.systemDefault())
                                        .toInstant()
                                        .toEpochMilli());
                    }
                };
            case DECIMAL:
                return new SerializableFunction<Object, BsonValue>() {
                    private static final long serialVersionUID = 1L;

                    @Override
                    public BsonValue apply(Object value) {
                        if (!(value instanceof BigDecimal) || !(type instanceof DecimalType)) {
                            throw unsupportedValue(type, value);
                        }
                        DecimalType decimalType = (DecimalType) type;
                        BigDecimal decimalVal = (BigDecimal) value;
                        int precision = decimalType.getPrecision();
                        int scale = decimalType.getScale();

                        Optional<BigDecimal> optionalValue =
                                fromBigDecimal(decimalVal, precision, scale);
                        return optionalValue
                                .map(item -> new BsonDecimal128(new Decimal128(item)))
                                .orElseThrow(
                                        () ->
                                                new IllegalArgumentException(
                                                        "Invalid BigDecimal value "
                                                                + decimalVal
                                                                + " for DECIMAL("
                                                                + precision
                                                                + ", "
                                                                + scale
                                                                + ")"));
                    }
                };
            case ARRAY:
                return createArrayConverter((ArrayType<?, ?>) type);
            case MAP:
                MapType<?, ?> mapType = (MapType<?, ?>) type;
                return createMapConverter(
                        mapType.toString(), mapType.getKeyType(), mapType.getValueType());
            case ROW:
                if (type instanceof SeaTunnelRowType) {
                    return createRowConverter((SeaTunnelRowType) type);
                }
                throw new MongodbConnectorException(
                        UNSUPPORTED_DATA_TYPE, "Not support to parse type: " + type);
            default:
                throw new MongodbConnectorException(
                        UNSUPPORTED_DATA_TYPE, "Not support to parse type: " + type);
        }
    }

    /** Converts an {@code Object[]} element by element into a {@link BsonArray}. */
    private static SerializableFunction<Object, BsonValue> createArrayConverter(
            ArrayType<?, ?> arrayType) {
        final SerializableFunction<Object, BsonValue> elementConverter =
                createNullSafeInternalConverter(arrayType.getElementType());

        return new SerializableFunction<Object, BsonValue>() {
            private static final long serialVersionUID = 1L;

            @Override
            public BsonValue apply(Object value) {
                Object[] elements = (Object[]) value;
                List<BsonValue> bsonValues = new ArrayList<>(elements.length);
                for (Object element : elements) {
                    bsonValues.add(elementConverter.apply(element));
                }
                return new BsonArray(bsonValues);
            }
        };
    }

    /**
     * Converts a {@code Map<String, ?>} into a {@link BsonDocument} keyed by the map keys.
     *
     * @throws MongodbConnectorException if the key type is not STRING
     */
    private static SerializableFunction<Object, BsonValue> createMapConverter(
            String typeSummary, SeaTunnelDataType<?> keyType, SeaTunnelDataType<?> valueType) {
        if (!SqlType.STRING.equals(keyType.getSqlType())) {
            throw new MongodbConnectorException(
                    CommonErrorCode.UNSUPPORTED_OPERATION,
                    "JSON format doesn't support non-string as key type of map. The type is: "
                            + typeSummary);
        }

        final SerializableFunction<Object, BsonValue> valueConverter =
                createNullSafeInternalConverter(valueType);

        return new SerializableFunction<Object, BsonValue>() {
            private static final long serialVersionUID = 1L;

            @Override
            public BsonValue apply(Object value) {
                // Safe by construction: the key type was verified to be STRING above.
                @SuppressWarnings("unchecked")
                Map<String, ?> mapData = (Map<String, ?>) value;
                final BsonDocument document = new BsonDocument();
                for (Map.Entry<String, ?> entry : mapData.entrySet()) {
                    document.append(entry.getKey(), valueConverter.apply(entry.getValue()));
                }
                return document;
            }
        };
    }

    /**
     * Converts a {@link SeaTunnelRow} into a {@link BsonDocument}. Each field is keyed by its name
     * from {@code rowType}.
     */
    private static SerializableFunction<Object, BsonValue> createRowConverter(
            SeaTunnelRowType rowType) {
        List<SerializableFunction<Object, BsonValue>> fieldConverters =
                rowType.getChildren().stream()
                        .map(RowDataToBsonConverters::createNullSafeInternalConverter)
                        .collect(Collectors.toList());

        final int fieldCount = rowType.getTotalFields();
        final String[] fieldNames = rowType.getFieldNames();

        return new SerializableFunction<Object, BsonValue>() {
            private static final long serialVersionUID = 1L;

            @Override
            public BsonValue apply(Object value) {
                // Fail with a descriptive error rather than an NPE on a non-row value.
                if (!(value instanceof SeaTunnelRow)) {
                    throw unsupportedValue(rowType, value);
                }
                SeaTunnelRow rowData = (SeaTunnelRow) value;
                final BsonDocument document = new BsonDocument();
                for (int i = 0; i < fieldCount; i++) {
                    document.append(
                            fieldNames[i], fieldConverters.get(i).apply(rowData.getField(i)));
                }
                return document;
            }
        };
    }
}
